{
 "cells": [
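  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Neural Architecture Search for Image Classification: Quick Start\n",
    "\n",
    "This tutorial uses the image classification model MobileNetV2 as an example to show how to get started quickly with the [neural architecture search API](../api/nas_api.md) on the cifar10 dataset.\n",
    "The example consists of the following steps:\n",
    "\n",
    "1. Import dependencies\n",
    "2. Initialize the SANAS search instance\n",
    "3. Build the network\n",
    "4. Define the input data function\n",
    "5. Define the training function\n",
    "6. Define the evaluation function\n",
    "7. Start the search experiment\n",
    "   - 7.1 Get the model architecture\n",
    "   - 7.2 Build the programs\n",
    "   - 7.3 Define the input data\n",
    "   - 7.4 Train the model\n",
    "   - 7.5 Evaluate the model\n",
    "   - 7.6 Report the current model's score\n",
    "8. Complete example\n",
    "\n",
    "\n",
    "The following sections describe each step in turn."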
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Import dependencies\n",
    "Make sure Paddle and PaddleSlim are installed correctly, then import the required packages."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import paddle\n",
    "import paddle.fluid as fluid\n",
    "import paddleslim as slim\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Initialize the SANAS search instance"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2020-02-07 08:42:37,895-INFO: range table: ([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [7, 5, 8, 6, 2, 5, 8, 6, 2, 5, 8, 6, 2, 5, 10, 6, 2, 5, 10, 6, 2, 5, 12, 6, 2])\n",
      "2020-02-07 08:42:37,897-INFO: ControllerServer - listen on: [10.255.125.38:8339]\n",
      "2020-02-07 08:42:37,899-INFO: Controller Server run...\n"
     ]
    }
   ],
   "source": [
    "sanas = slim.nas.SANAS(configs=[('MobileNetV2Space')], server_addr=(\"\", 8339), save_checkpoint=None)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. Build the network\n",
    "Build the training program and the test program from the given network architecture."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def build_program(archs):\n",
    "    train_program = fluid.Program()\n",
    "    startup_program = fluid.Program()\n",
    "    with fluid.program_guard(train_program, startup_program):\n",
    "        # Input placeholders: CIFAR-10 images (3x32x32) and integer class labels.\n",
    "        data = fluid.data(name='data', shape=[None, 3, 32, 32], dtype='float32')\n",
    "        label = fluid.data(name='label', shape=[None, 1], dtype='int64')\n",
    "        # archs is a callable that builds the sampled network on top of the input.\n",
    "        output = archs(data)\n",
    "        output = fluid.layers.fc(input=output, size=10)\n",
    "\n",
    "        softmax_out = fluid.layers.softmax(input=output, use_cudnn=False)\n",
    "        cost = fluid.layers.cross_entropy(input=softmax_out, label=label)\n",
    "        avg_cost = fluid.layers.mean(cost)\n",
    "        acc_top1 = fluid.layers.accuracy(input=softmax_out, label=label, k=1)\n",
    "        acc_top5 = fluid.layers.accuracy(input=softmax_out, label=label, k=5)\n",
    "        # Clone the test program before adding the optimizer so it holds only forward ops.\n",
    "        test_program = fluid.default_main_program().clone(for_test=True)\n",
    "\n",
    "        optimizer = fluid.optimizer.Adam(learning_rate=0.1)\n",
    "        optimizer.minimize(avg_cost)\n",
    "\n",
    "        # Run the startup program once on CPU to initialize the parameters.\n",
    "        place = fluid.CPUPlace()\n",
    "        exe = fluid.Executor(place)\n",
    "        exe.run(startup_program)\n",
    "    return exe, train_program, test_program, (data, label), avg_cost, acc_top1, acc_top5"
   ]
  },
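  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4. Define the input data function\n",
    "The dataset is cifar10. The `paddle.dataset.cifar` module in the Paddle framework handles downloading and reading the cifar dataset; the code is as follows:"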
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def input_data(inputs):\n",
    "    # Shuffled training batches and unshuffled evaluation batches, 256 samples each.\n",
    "    train_reader = paddle.batch(paddle.reader.shuffle(paddle.dataset.cifar.train10(cycle=False), buf_size=1024), batch_size=256)\n",
    "    train_feeder = fluid.DataFeeder(inputs, fluid.CPUPlace())\n",
    "    eval_reader = paddle.batch(paddle.dataset.cifar.test10(cycle=False), batch_size=256)\n",
    "    eval_feeder = fluid.DataFeeder(inputs, fluid.CPUPlace())\n",
    "    return train_reader, train_feeder, eval_reader, eval_feeder"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 5. Define the training function\n",
    "Train the model with the training program and the training data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def start_train(program, data_reader, data_feeder):\n",
    "    # Relies on the globals exe, avg_cost, acc_top1 and acc_top5 returned by build_program.\n",
    "    outputs = [avg_cost.name, acc_top1.name, acc_top5.name]\n",
    "    for data in data_reader():\n",
    "        batch_reward = exe.run(program, feed=data_feeder.feed(data), fetch_list=outputs)\n",
    "        print(\"TRAIN: loss: {}, acc1: {}, acc5:{}\".format(batch_reward[0], batch_reward[1], batch_reward[2]))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 6. Define the evaluation function\n",
    "Evaluate the model with the evaluation program and the evaluation data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def start_eval(program, data_reader, data_feeder):\n",
    "    # Relies on the globals exe, avg_cost, acc_top1 and acc_top5 returned by build_program.\n",
    "    reward = []\n",
    "    outputs = [avg_cost.name, acc_top1.name, acc_top5.name]\n",
    "    for data in data_reader():\n",
    "        batch_reward = exe.run(program, feed=data_feeder.feed(data), fetch_list=outputs)\n",
    "        # Reduce each fetched metric to a scalar for this batch.\n",
    "        reward_avg = np.mean(np.array(batch_reward), axis=1)\n",
    "        reward.append(reward_avg)\n",
    "        print(\"TEST: loss: {}, acc1: {}, acc5:{}\".format(batch_reward[0], batch_reward[1], batch_reward[2]))\n",
    "    # Unweighted mean over all batches: [avg_cost, acc_top1, acc_top5].\n",
    "    finally_reward = np.mean(np.array(reward), axis=0)\n",
    "    print(\"FINAL TEST: avg_cost: {}, acc1: {}, acc5: {}\".format(finally_reward[0], finally_reward[1], finally_reward[2]))\n",
    "    return finally_reward"
   ]
  },
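  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 7. Start the search experiment\n",
    "The following steps break down how to obtain the current model architecture and what to do once you have it. For a complete example of launching a search experiment, see step 8.\n",
    "\n",
    "### 7.1 Get the model architecture\n",
    "Call `next_archs()` to get the next model architecture."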
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2020-02-07 08:42:45,035-INFO: current tokens: [4, 4, 5, 1, 0, 4, 4, 2, 0, 4, 4, 3, 0, 4, 5, 2, 0, 4, 7, 2, 0, 4, 9, 0, 0]\n"
     ]
    }
   ],
   "source": [
    "archs = sanas.next_archs()[0]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 7.2 Build the programs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "exe, train_program, eval_program, inputs, avg_cost, acc_top1, acc_top5 = build_program(archs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 7.3 Define the input data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_reader, train_feeder, eval_reader, eval_feeder = input_data(inputs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 7.4 Train the model\n",
    "Start training with the training program and the training data obtained above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "TRAIN: loss: [2.7999306], acc1: [0.1015625], acc5:[0.44140625]\n"
     ]
    }
   ],
   "source": [
    "start_train(train_program, train_reader, train_feeder)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 7.5 Evaluate the model\n",
    "Start evaluation with the evaluation program and the evaluation data obtained above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "TEST: loss: [49.99942], acc1: [0.078125], acc5:[0.46484375]\n",
      "FINAL TEST: avg_cost: 49.999420166, acc1: 0.078125, acc5: 0.46484375\n"
     ]
    }
   ],
   "source": [
    "finally_reward = start_eval(eval_program, eval_reader, eval_feeder)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 7.6 Report the current model's score"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2020-02-07 08:44:26,774-INFO: Controller - iter: 1; best_reward: 0.078125, best tokens: [4, 4, 5, 1, 0, 4, 4, 2, 0, 4, 4, 3, 0, 4, 5, 2, 0, 4, 7, 2, 0, 4, 9, 0, 0], current_reward: 0.078125; current tokens: [4, 4, 5, 1, 0, 4, 4, 2, 0, 4, 4, 3, 0, 4, 5, 2, 0, 4, 7, 2, 0, 4, 9, 0, 0]\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sanas.reward(float(finally_reward[1]))"
   ]
  },
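  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 8. Complete example\n",
    "Below is a complete search experiment that uses FLOPs as the constraint. The search runs for 3 steps. At each step one architecture is sampled, skipped if it exceeds the FLOPs limit, and otherwise trained for 7 epochs before being evaluated."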
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2020-02-07 08:45:06,927-INFO: current tokens: [4, 4, 5, 1, 0, 4, 4, 2, 0, 4, 4, 3, 1, 4, 5, 2, 0, 4, 7, 2, 0, 4, 9, 0, 0]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "TRAIN: loss: [2.6932292], acc1: [0.08203125], acc5:[0.51953125]\n",
      "TRAIN: loss: [42.387478], acc1: [0.078125], acc5:[0.47265625]\n"
     ]
    }
   ],
   "source": [
    "for step in range(3):\n",
    "    archs = sanas.next_archs()[0]\n",
    "    exe, train_program, eval_program, inputs, avg_cost, acc_top1, acc_top5 = build_program(archs)\n",
    "    train_reader, train_feeder, eval_reader, eval_feeder = input_data(inputs)\n",
    "\n",
    "    # Skip architectures whose FLOPs exceed the budget.\n",
    "    current_flops = slim.analysis.flops(train_program)\n",
    "    if current_flops > 321208544:\n",
    "        continue\n",
    "\n",
    "    for epoch in range(7):\n",
    "        start_train(train_program, train_reader, train_feeder)\n",
    "\n",
    "    finally_reward = start_eval(eval_program, eval_reader, eval_feeder)\n",
    "\n",
    "    # Report top-1 accuracy to the controller as the reward.\n",
    "    sanas.reward(float(finally_reward[1]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
